#!/usr/bin/python
# -*- coding: utf-8 -*-
# Eli Criffield < python AT zendo dot net >
#
# TODO: stderr is mixed in with stdout
# TODO: cmd line options should be more standard, and parsed by
#       ConchOptions()


''' This will connect to a host and run a list of commands.
It's different from just a normal connection because it launches a shell
and waits for the prompt, then sends a command, waits for the prompt, sends
the next command, etc. sshCmds() will return the output of all the commands
it was given, concatenated together.

You probably want to run this on the command line or use sshCmds().
On the command line run:
conchClient.py [user@]host command1 command2

In your python script use:

from conchClient import sshCmds
(out,err) = sshCmds(host,listOfCmds,user,echo)
print out

host is required
listOfCmds is the list of commands to be run, in order
user can be None, then the environment variable LOGNAME will be used
echo must be True or False
err is currently always empty; stderr is mixed in with out.
'''

# Module metadata.
__version__ = 0.1185457318
__author__ = 'Eli Criffield <ecriffield at fnni.com>'


import os
import re
import sys
import struct
import base64
from twisted.conch import error
from twisted.python import log
from twisted.conch.client import default, options
from twisted.conch.ssh import channel, common, connection, keys, session, \
    transport, userauth
from twisted.internet import defer, protocol, reactor, stdio

# Patterns matched against buffered shell output (multi-line mode, so `$`
# also matches just before each newline).
# A shell prompt: a '>', '#' or '$' followed by one space at end of line.
rxPrompt = re.compile(r'[\>\#\$] $', re.MULTILINE)
# A pager pause (e.g. on Cisco devices) waiting for a key press.
rxMore = re.compile(r'--More--', re.MULTILINE)


class SSHClientFactory(protocol.ClientFactory):
    """Client factory that builds an SSHClientTransport for a single host.

    The commands are kept in reverse order so that SSHCommandChannel can
    take the next one with list.pop().  The reactor is stopped when the
    factory stops.
    """

    def __init__(self, cmdlist, host, user, outputer):
        self.host = host
        self.user = user
        # Work on a private copy: accepts any iterable of commands and never
        # mutates the caller's list (it is reversed here and consumed by
        # pop() in the channel).
        self.cmdlist = list(cmdlist)
        self.cmdlist.reverse()
        self.outputer = outputer

    def stopFactory(self):
        try:
            reactor.stop()
        except RuntimeError:
            # The reactor may already be stopped, e.g. by
            # SSHCommandChannel.closed(); Twisted signals that with a
            # RuntimeError (ReactorNotRunning subclasses it).
            pass

    def buildProtocol(self, addr):
        clientTransport = SSHClientTransport(cmdlist=self.cmdlist,
                                             outputer=self.outputer)
        clientTransport.user = self.user
        clientTransport.host = self.host
        return clientTransport

    def clientConnectionFailed(self, connector, reason):
        print('Connection failed: %s' % reason)

    def clientConnectionLost(self, connector, reason):
        pass

class ClientOptions(options.ConchOptions):
    """ConchOptions preloaded with the private-key files to offer."""

    # Candidate identity files; presumably read by
    # default.SSHUserAuthClient during public-key auth -- TODO confirm.
    identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']


class SSHClientUserAuth(default.SSHUserAuthClient):
    """User-auth service authenticating as *user* with ClientOptions.

    The initializer of userauth.SSHUserAuthClient is invoked directly,
    skipping default.SSHUserAuthClient's own, and keyAgent, options and
    usedFiles are then set by hand (presumably the attributes the skipped
    initializer would provide -- TODO confirm for the Twisted in use).
    """

    def __init__(self, user, *serviceArgs):
        userauth.SSHUserAuthClient.__init__(self, user, *serviceArgs)
        self.usedFiles = []
        self.options = ClientOptions()
        self.keyAgent = None


class SSHClientConnection(connection.SSHConnection):
    """Connection service that opens one SSHCommandChannel once started."""

    def __init__(self, cmdlist, outputer, *args, **kwargs):
        connection.SSHConnection.__init__(self, *args, **kwargs)
        self.outputer = outputer
        self.cmdlist = cmdlist

    def serviceStarted(self):
        log.msg('Opening command channel')
        commandChannel = SSHCommandChannel(conn=self,
                                           cmdlist=self.cmdlist,
                                           outputer=self.outputer)
        self.openChannel(commandChannel)


class SSHClientTransport(transport.SSHClientTransport):
    """Client transport: verifies the host key, then starts user auth.

    ``user`` and ``host`` are assigned by SSHClientFactory.buildProtocol
    after construction; ``cmdlist`` and ``outputer`` are handed on to the
    SSHClientConnection created once the transport is secure.
    """

    # Display names for the host-key types the prompt knows about; any
    # other type is shown by its raw SSH name.
    _hostKeyTypeNames = {'ssh-dss': 'DSA', 'ssh-rsa': 'RSA'}

    def __init__(self, cmdlist, outputer):
        self.cmdlist = cmdlist
        self.outputer = outputer

    def receiveError(self, code, desc):
        print('disconnected error %i: %s' % (code, desc))

    def sendDisconnect(self, code, reason):
        print('disconnect error %i: %s' % (code, reason))
        transport.SSHClientTransport.sendDisconnect(self, code, reason)

    def verifyHostKey(self, pubKey, fingerprint):
        """Check *pubKey* against known_hosts, asking the user if it is new.

        Returns a Deferred that succeeds with 1 when the key is accepted
        and fails with error.ConchError when the key has changed, the user
        declines it, interrupts the prompt, or cannot be asked at all.
        """
        host = self.host
        goodKey = default.isInKnownHosts(host, pubKey, {'known-hosts': None})
        if goodKey == 1:  # known and matching
            return defer.succeed(1)
        if goodKey == 2:  # known but different: refuse
            return defer.fail(error.ConchError('changed host key'))

        peerHost = self.transport.getPeer().host
        if host == peerHost:
            khHost = host
        else:
            # Build the known_hosts entry from the bare name *before* host
            # is decorated with the address for display.
            khHost = '%s,%s' % (host, peerHost)
            host = '%s (%s)' % (host, peerHost)
        keyType = common.getNS(pubKey)[0]
        keyName = self._hostKeyTypeNames.get(keyType, keyType)

        try:
            ans = self._askToAccept(host, keyName, fingerprint)
        except (KeyboardInterrupt, EOFError):
            return defer.fail(error.ConchError('^C'))
        except (IOError, OSError):
            # No usable /dev/tty (e.g. not run from a terminal): we cannot
            # ask, so treat the key as rejected.
            print('Host key verification failed.')
            return defer.fail(error.ConchError('cannot confirm host key: no tty'))

        if ans == 'no':
            print('Host key verification failed.')
            return defer.fail(error.ConchError('bad host key'))
        print("Warning: Permanently added '%s' (%s) to the list of known hosts."
              % (khHost, keyName))
        self._addKnownHost(khHost, keyType, pubKey)
        return defer.succeed(1)

    def _askToAccept(self, host, keyName, fingerprint):
        """Prompt on /dev/tty about an unknown key; return 'yes' or 'no'.

        sys.stdin/sys.stdout point at the terminal only for the duration of
        the prompt and are always restored; the terminal handle is closed.
        Raises IOError if /dev/tty cannot be opened; KeyboardInterrupt and
        EOFError from the prompt propagate to the caller.
        """
        tty = open('/dev/tty', 'r+')
        (oldout, oldin) = (sys.stdout, sys.stdin)
        sys.stdin = sys.stdout = tty
        try:
            print("The authenticity of host '%s' can't be established.\n"
                  "    %s key fingerprint is %s."
                  % (host, keyName, fingerprint))
            ans = raw_input('Are you sure you want to continue connecting (yes/no)? ')
            while ans.lower() not in ('yes', 'no'):
                ans = raw_input("Please type 'yes' or 'no': ")
        finally:
            (sys.stdout, sys.stdin) = (oldout, oldin)
            tty.close()
        # Normalise case so 'NO' is not mistaken for acceptance.
        return ans.lower()

    def _addKnownHost(self, khHost, keyType, pubKey):
        """Append a '<hosts> <type> <base64 key>' line to ~/.ssh/known_hosts.

        The file is created if missing; a newline is inserted first when the
        existing file does not end with one.
        """
        path = os.path.expanduser('~/.ssh/known_hosts')
        encodedKey = base64.b64encode(pubKey)
        known_hosts = open(path, 'a+')
        try:
            known_hosts.seek(0, 2)
            if known_hosts.tell() > 0:
                known_hosts.seek(-1, 2)
                needsNewline = known_hosts.read(1) != '\n'
                # Reposition between a read and a write, as stdio requires.
                known_hosts.seek(0, 2)
                if needsNewline:
                    known_hosts.write('\n')
            known_hosts.write('%s %s %s\n' % (khHost, keyType, encodedKey))
        finally:
            known_hosts.close()

    def connectionSecure(self):
        log.msg('Securing connection')
        clientConnection = SSHClientConnection(cmdlist=self.cmdlist,
                                               outputer=self.outputer)
        self.requestService(SSHClientUserAuth(self.user, clientConnection))


class SSHCommandChannel(channel.SSHChannel):
    """Session channel that drives an interactive shell one command at a time.

    After requesting a pty and a shell, it buffers output until it matches
    rxPrompt. It then sends the next command popped from the end of
    ``cmdlist`` (SSHClientFactory stores the list reversed, so commands run
    in their original order). It closes the channel once the list is empty.
    Every chunk of output is passed to ``outputer``.
    """

    name = 'session'
    # Terminal requested for the shell. session.packRequest_pty_req takes
    # the geometry as (rows, columns, xpixel, ypixel); zero pixel sizes mean
    # "unspecified". This asks for a standard 80x25 terminal.
    termType = 'ansi'
    termGeometry = (25, 80, 0, 0)

    def __init__(self, cmdlist, outputer, *args, **kwargs):
        channel.SSHChannel.__init__(self, *args, **kwargs)
        self.cmdlist = cmdlist
        self.outputer = outputer

    def openFailed(self, reason):
        print('channel open failed: %s' % reason)

    def channelOpen(self, data):
        log.msg('Channel open')
        ptyReqData = session.packRequest_pty_req(self.termType,
                                                 self.termGeometry, '')
        self.conn.sendRequest(self, 'pty-req', ptyReqData)
        self.conn.sendRequest(self, 'shell', '')
        # Output received since the last prompt / pager pause.
        self.data = ''

    def dataReceived(self, data):
        self.data += data
        self.outputer.moreStdout(data)

        if rxPrompt.search(self.data):
            self.data = ''
            if self.cmdlist:
                cmd = self.cmdlist.pop()
                log.msg('execing: %s' % cmd)
                self.write(cmd)
                self.write('\n')
            else:
                self.loseConnection()
        # Needed for Cisco devices: with a terminal attached, the pager
        # stops at each screenful with --More--; a space shows the next one.
        if rxMore.search(self.data):
            self.data = ''
            self.write(' ')

    def extReceived(self, t, data):
        # Route stderr through the outputer so it is echoed (when echo is
        # on) and collected like stdout. With a pty, most stderr usually
        # arrives merged into the normal data stream anyway.
        if t == connection.EXTENDED_DATA_STDERR:
            self.outputer.moreStderr(data)

    def eofReceived(self):
        log.msg('Received EOF')
        self.loseConnection()

    def closeReceived(self):
        log.msg('Remote side closed')
        self.conn.sendClose(self)

    def closed(self):
        log.msg('Channel closed')
        try:
            reactor.stop()
        except RuntimeError:
            # The reactor is already stopping (for example, after ^C).
            # ReactorNotRunning is a RuntimeError subclass.
            log.msg('Reactor already stopped')
        else:
            log.msg('Reactor Stoped')

    def request_exit_status(self, data):
        global exitStatus
        exitStatus = int(struct.unpack('>L', data)[0])
        log.msg('Exit status: %s' % exitStatus)

    def sendEOF(self):
        self.conn.sendEOF(self)

    def stopWriting(self):
        pass

    def startWriting(self):
        pass


class saveOutput:
    """Collects a session's stdout and stderr text.

    With echo enabled, each chunk is also written to the local stream and
    flushed as soon as it arrives.
    """

    def __init__(self, echo=False):
        self.echo = echo
        self.stdout = self.stderr = ''

    def _mirror(self, stream, data):
        # Show output live when echoing; flush so it is not held back.
        if not self.echo:
            return
        stream.write(data)
        stream.flush()

    def moreStdout(self, data):
        self._mirror(sys.stdout, data)
        self.stdout = self.stdout + data

    def moreStderr(self, data):
        self._mirror(sys.stderr, data)
        self.stderr = self.stderr + data


def __sshCmds(host, cmds, user=None, echo=False):
    """Run *cmds* on *host* inside this process; return (stdout, stderr).

    Connects to port 22 and blocks in reactor.run() until the reactor is
    stopped, either by the command channel closing or by the client factory
    stopping. The reactor does not work well when run twice, so callers
    should normally go through sshCmds, which calls this in a child process.
    """
    collector = saveOutput(echo=echo)
    factory = SSHClientFactory(cmdlist=cmds, host=host, user=user,
                               outputer=collector)
    log.msg('Connecting to:%s' % host)
    reactor.connectTCP(host, 22, factory)
    log.msg('Starting reactor')
    reactor.run()
    log.msg('Reactor has ended')
    return collector.stdout, collector.stderr


def sshCmds(host, cmds, user=None, echo=False):
    '''Run commands on host through an interactive SSH shell.

    from conchClient import sshCmds
    (out,err) = sshCmds(host,listOfCmds,user,echo)
    print out

    host is required
    listOfCmds is the list of commands to be run, in order
    user can be None, then the environment variable LOGNAME will be used
    echo must be True or False

    The session runs in a forked child, because the Twisted reactor does
    not work well when run twice in one process. The child's stdout is sent
    back over a pipe. The child always terminates with os._exit, so an
    error there never resumes the caller's code in the child process.
    FIXME: err is always '' -- stderr is mixed in with stdout.
    '''
    if user is None:
        user = os.environ.get('LOGNAME')

    # Flush first so buffered output is not duplicated into the child.
    sys.stdout.flush()
    sys.stderr.flush()
    (rPipe, wPipe) = os.pipe()  # these are file descriptors, not file objects
    try:
        pid = os.fork()
    except OSError:
        os.close(rPipe)
        os.close(wPipe)
        raise

    if pid:
        # I'm the Parent: read everything the child sends, then reap it.
        os.close(wPipe)
        rfd = os.fdopen(rPipe)
        try:
            outpt = rfd.read()
        finally:
            rfd.close()
            os.waitpid(pid, 0)  # make sure the child process gets cleaned up
        # FIXME, stderr is mixed in with stdout
        return (outpt, '')

    # I'm the Child: run the session, hand stdout to the parent, and always
    # leave via os._exit. sys.exit would raise SystemExit into the caller's
    # code and run the copied process's exit handlers.
    status = 1
    try:
        try:
            os.close(rPipe)
            wfd = os.fdopen(wPipe, 'w')
            (stdout, stderr) = __sshCmds(host, cmds, user, echo)
            # FIXME Stderr is mixed with stdout
            wfd.write(stdout)
            wfd.close()
            status = 0
        except Exception:
            import traceback
            traceback.print_exc()
    finally:
        try:
            # os._exit skips stdio flushing; push out anything pending.
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            os._exit(status)


if __name__ == '__main__':

    if len(sys.argv) < 3:
        # Usage goes to stderr, with the script name actually substituted.
        sys.stderr.write('%s [user@]host cmd1 [cmd2 cmd3]\n' % sys.argv[0])
        sys.exit(1)
    # Split "user@host" on the last '@'; host names cannot contain one.
    target = sys.argv[1]
    if '@' in target:
        (user, host) = target.rsplit('@', 1)
    else:
        user = None
        host = target
    cmds = sys.argv[2:]

    (out, err) = sshCmds(host, cmds, user, True)
    print('\n')
